#ifdef INFINI_USE_TVM
#include "core/kernel.h"
#include "cuda/cuda_runtime.h"
#include "ffi/ffi_embed.h"
#include "nnet/Visitor/AsTVMVisitor.h"
#include "nnet/Visitor/CheckOOBVisitor.h"
#include "nnet/Visitor/HashVisitor.h"
#include "nnet/Visitor/MergeMemboundMutator.h"
#include "nvrtc.h"
#include "operators/membound.h"
#include "operators/pooling.h"

namespace py = pybind11;

namespace infini {

/// Performance record for a membound kernel generated through TVM.
///
/// Besides the base PerfRecordObj data, it keeps the NVRTC build artifacts and
/// the launch configuration, so the kernel can be reloaded and launched from
/// the record alone.
class TVMRecordObj : public PerfRecordObj {
    // TODO: Add more attrs
  public:
    /// Sizes reported by NVRTC for `log` and `ptx`. Zero-initialized so a
    /// default-constructed record never exposes indeterminate values.
    size_t logSize = 0, ptxSize = 0;
    /// NVRTC compilation log and the generated PTX text.
    std::string log, ptx;
    /// Launch dimensions forwarded to cuLaunchKernel in order:
    /// grid x/y/z followed by block x/y/z (six entries).
    std::vector<int> invokeParams;
    /// Name of the kernel entry point looked up inside `ptx`.
    std::string kernelName;
    /// Hash of the nnet expression the kernel was generated from.
    HashType simplifiedExprHash{};
};

// Reference-counted handle to a TVMRecordObj; tune() creates one with
// std::make_shared.
using TVMRecord = Ref<TVMRecordObj>;

class MemboundTVMExtractSource : public Kernel {
  public:
    void compute(const Operator &_op, const PerfRecord &record,
                 const RuntimeObj *_context) const override {
        auto op = as<MemBoundObj>(_op);
        // auto context = dynamic_cast<const CudaRuntimeObj *>(_context);
        auto tvmRecord = std::dynamic_pointer_cast<TVMRecordObj>(record);

        // prepare for evaluation
        CUmodule module;
        CUfunction kernel;
        checkCUresult(cuModuleLoadDataEx(&module, tvmRecord->ptx.data(), 0,
                                         nullptr, nullptr));
        checkCUresult(cuModuleGetFunction(&kernel, module,
                                          tvmRecord->kernelName.c_str()));
        std::vector<void *> args;
        for (auto &&in : op->getInputs()) {
            args.push_back(in->getRawDataPtr<void *>());
        }
        args.push_back(op->getOutput()->getRawDataPtr<void *>());
        std::vector<void *> argsPtr;
        for (auto &arg : args) {
            argsPtr.push_back(&arg);
        }
        auto invokeParams = tvmRecord->invokeParams;

        // begin evaluation
        cuLaunchKernel(kernel, invokeParams[0], invokeParams[1],
                       invokeParams[2], invokeParams[3], invokeParams[4],
                       invokeParams[5], 0, NULL, argsPtr.data(), 0);

        // free module
        checkCUresult(cuModuleUnload(module));
    }

    void compute(const Operator &_op,
                 const RuntimeObj *_context) const override {
        IT_ASSERT(false, "A TVM record is required for membound kernel.");
    }

    std::string getVarName(const Tensor &t) const {
        return "var_" + std::to_string(t->getGuid());
    }

    bool checkOOB(nnet::Expr expr) const {
        return nnet::CheckOOBVisitor().checkRangeOp(
            nnet::as<nnet::RangeOpNode>(expr));
    }

    // Premise: op is idempotent since it is called multiple times.
    PerfRecord tune(const Operator &_op,
                    const RuntimeObj *_context) const override {
        TVMRecord ret = std::make_shared<TVMRecordObj>();
        auto op = as<MemBoundObj>(_op);
        auto context = dynamic_cast<const CudaRuntimeObj *>(_context);

        // invoke Ansor to tune a membound kernel
        nnet::AsTVMVisitor visitor;
        IT_ASSERT(!checkOOB(op->getNnetExpr()));
        // fuse stages in nnet expr to reduce kernels generated by TVM
        auto expr = op->getNnetExpr();
        if (auto mergedExpr =
                nnet::MergeMemboundMutator({expr}).merge(false, true))
            expr = mergedExpr;

        nnet::HashVisitor hashVisitor;
        HashType hashCode = hashVisitor.getHash(expr);

        visitor.dispatch(expr);
        auto &&stmts = visitor.getStmts();
        auto &&inShapes = visitor.getInputShapes();
        auto &&outShape = visitor.getOutputShape();

        std::vector<std::string> inputs;
        for (auto &&in : op->getInputs()) {
            inputs.emplace_back(getVarName(in));
        }
        const std::string output = getVarName(op->getOutput());

        const std::string func = "membound_" + std::to_string(hashCode);
        const std::string kernelName = func + "_kernel0";
        auto res = getAnsorCode(
            inShapes, std::vector<std::string>(inShapes.size(), "float32"),
            outShape, "float32", stmts, func, inputs, output, op->toString(),
            expr->toReadable(), hashCode);

        // compile the kernel
        auto funcCode = res.first;
        auto invokeParams = res.second;
        std::string fileName = func + ".cu";
        nvrtcProgram prog;
        nvrtcCreateProgram(&prog,            // prog
                           funcCode.c_str(), // buffer
                           fileName.c_str(), // name
                           0,                // numHeaders
                           NULL,             // headers
                           NULL);            // includeNames
        const char *opts[] = {"--gpu-architecture=compute_80", "--fmad=false"};
        nvrtcCompileProgram(prog,  // prog
                            2,     // numOptions
                            opts); // options

        // copy ptx and log to ret
        size_t logSize;
        nvrtcGetProgramLogSize(prog, &logSize);
        size_t ptxSize;
        nvrtcGetPTXSize(prog, &ptxSize);
        ret->logSize = logSize;
        ret->ptxSize = ptxSize;
        ret->log = std::string(logSize, ' ');
        ret->ptx = std::string(ptxSize, ' ');
        nvrtcGetProgramLog(prog, ret->log.data());
        nvrtcGetPTX(prog, ret->ptx.data());
        ret->invokeParams = invokeParams;
        ret->kernelName = kernelName;
        ret->simplifiedExprHash = hashCode;

        // prepare for evaluation
        CUmodule module;
        CUfunction kernel;
        checkCUresult(
            cuModuleLoadDataEx(&module, ret->ptx.data(), 0, nullptr, nullptr));
        checkCUresult(cuModuleGetFunction(&kernel, module, kernelName.c_str()));
        std::vector<void *> args;
        for (auto &&in : op->getInputs())
            args.push_back(in->getRawDataPtr<void *>());
        args.push_back(op->getOutput()->getRawDataPtr<void *>());
        std::vector<void *> argsPtr;
        for (auto &arg : args)
            argsPtr.push_back(&arg);

        // Evaluate the kernel
        ret->time = timeit(
            [&]() {
                cuLaunchKernel(kernel, invokeParams[0], invokeParams[1],
                               invokeParams[2], invokeParams[3],
                               invokeParams[4], invokeParams[5], 0, NULL,
                               argsPtr.data(), 0);
            },
            [&]() { context->sync(); });

        // free module
        checkCUresult(cuModuleUnload(module));
        nvrtcDestroyProgram(&prog);

        return std::dynamic_pointer_cast<PerfRecordObj>(ret);
    }

    /// @brief
    /// @param inDims
    /// @param inDTypes
    /// @param outDims
    /// @param outDType
    /// @param lambda
    /// @param funcName Generated function name
    /// @param inputNames Input array names in the generated invocation code.
    /// @param outputName Output array names in the generated invocation code.
    /// @param nnetExpressionString Save expr in string for logging.
    /// @param nnetSimplifiedExprString Save simplified expr in string for
    /// logging.
    /// @param hashCode (optional) Hash code of the input expression for kernel
    /// cache.
    /// @return
    std::pair<std::string, std::vector<int>>
    getAnsorCode(const std::vector<std::vector<int>> &inDims,
                 const std::vector<std::string> &inDTypes,
                 const std::vector<int> &outDims, const std::string &outDType,
                 const std::string &lambda, const std::string &funcName,
                 const std::vector<std::string> &inputNames,
                 const std::string &outputName,
                 const std::string &nnetExprString,
                 const std::string &nnetSimplifiedExprString,
                 const HashType hashCode) const {
        std::string funcCode;
        std::vector<int> invokeParams;
        try {
            start_interpreter();
            // Use static to avoid re-importing the module. Re-importing results
            // in cuBLAS failure, whose root cause is not identified yet.
            static auto func =
                py::module::import("cpp_plugin").attr("gen_ansor_op");
            py::tuple code =
                func(inDims, inDTypes, outDims, outDType, lambda, funcName,
                     inputNames, outputName, nnetExprString,
                     nnetSimplifiedExprString, std::to_string(hashCode));
            funcCode = py::str(code[0]);
            auto temp = py::list(code[3]);
            for (int i = 0; i < 6; ++i) {
                invokeParams.push_back(temp[i].cast<int>());
            }
        } catch (py::error_already_set &e) {
            if (e.matches(PyExc_ImportError)) {
                std::cerr << "Import Error. Don't forget to set environment "
                             "variable PYTHONPATH to contain "
                             "<repo-root>/python"
                          << std::endl;
            }
            throw;
        }
        return std::make_pair(funcCode, invokeParams);
    }
};

// Hook MemboundTVMExtractSource up through REGISTER_KERNEL for
// Device::CUDA / OpType::MemBound.
// NOTE(review): the name string is spelled "Memobund" (probably meant
// "Membound"); it is runtime text, so confirm nothing looks it up before
// renaming.
REGISTER_KERNEL(Device::CUDA, OpType::MemBound, MemboundTVMExtractSource,
                "Memobund_TVM_Ansor_extract_source");
}; // namespace infini

#endif
